#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import argparse
import json
import math
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf


def load_config() -> dict:
    """Load the project's JSON configuration.

    Reads ``media_project/config.json`` relative to the current working
    directory, falling back to ``media_project/config.example.json`` when the
    primary file does not exist.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: if the fallback file is also missing.
    """
    primary = Path("media_project/config.json")
    fallback = Path("media_project/config.example.json")
    chosen = primary if primary.exists() else fallback
    return json.loads(chosen.read_text(encoding="utf-8"))


def normalize_code6(s: pd.Series) -> pd.Series:
    """Normalize administrative codes to zero-padded 6-digit strings.

    The values are cast to pandas' nullable ``string`` dtype and trimmed.
    Empty, ``"nan"`` and ``"None"`` text is treated as missing, and a trailing
    ``.0`` (left over from float-typed CSV columns) is removed. The first run
    of digits is kept and left-padded with zeros to width 6. Values without
    any digit become ``<NA>``.

    Note: ``zfill`` never truncates, so codes longer than 6 digits pass
    through unchanged.
    """
    missing_tokens = {"": pd.NA, "nan": pd.NA, "None": pd.NA}
    return (
        s.astype("string")
        .str.strip()
        .replace(missing_tokens)
        .str.replace(r"\.0$", "", regex=True)
        .str.extract(r"(\d+)", expand=False)
        .str.zfill(6)
    )


def enc_event(e: int) -> str:
    """Encode an event time as a name-safe token: ``-3 -> "m3"``, ``2 -> "p2"``.

    The token is embedded in dummy-column names that appear inside statsmodels
    formula strings, so the minus sign is spelled out as an ``m`` prefix.
    Zero and positive event times use ``p``.
    """
    if e >= 0:
        return f"p{e}"
    return f"m{-e}"


@dataclass(frozen=True)
class EventStudyResult:
    """Immutable container for one cohort x event-time event-study fit.

    Built by ``sun_abraham_event_study``. Read by ``write_event_study_plot``
    and ``write_markdown_report``.
    """

    policy: str  # column prefix used to read "<policy>_treat" / "<policy>_pilot_year"
    outcome: str  # dependent-variable column used in the regression formula
    sample_note: str  # human-readable description of the specification
    window: list[int]  # event times requested (relative to the pilot year)
    ref_event: int  # event time omitted from the dummies (reference period)
    cohorts: list[int]  # treated cohorts (pilot years) found in the sample
    cohort_weights: dict[int, float]  # cohort -> share of treated cities in that cohort
    table: pd.DataFrame  # columns: event_time, beta, se, ci_low, ci_high, n_cohorts
    lead_joint_f: float | None = None  # F statistic of the joint lead test; None if not run or it failed
    lead_joint_p: float | None = None  # p-value of the joint lead test; None if not run or it failed
    lead_joint_k: int | None = None  # number of lead coefficients tested; None if there were none


def _safe_int_series(s: pd.Series) -> pd.Series:
    return pd.to_numeric(s, errors="coerce").fillna(0).astype(int)


def _pilot_year_series(df: pd.DataFrame, policy: str) -> pd.Series:
    return pd.to_numeric(df[f"{policy}_pilot_year"], errors="coerce")


def _treat_series(df: pd.DataFrame, policy: str) -> pd.Series:
    """Return the ``<policy>_treat`` indicator as ints (missing/unparseable -> 0)."""
    column = f"{policy}_treat"
    return _safe_int_series(df[column])


def sun_abraham_event_study(
    df: pd.DataFrame,
    *,
    policy: str,
    outcome: str,
    window: list[int],
    ref_event: int = -1,
) -> EventStudyResult:
    """Estimate a Sun-Abraham style cohort x event-time event study.

    Cohorts are the ``<policy>_pilot_year`` values of rows with
    ``<policy>_treat == 1``. For each cohort ``g`` and each event time ``e``
    in *window* other than *ref_event*, a 0/1 dummy
    ``D_g{g}_e{enc_event(e)}`` is added. The dummy is kept only when it is
    non-zero on at least one row with an observed outcome.

    The model is an OLS of *outcome* on these dummies plus city
    (``city_code6``) and year fixed effects. Standard errors are clustered by
    city.

    For each event time, the cohort coefficients are averaged using weights
    proportional to each cohort's number of treated cities. The weights are
    renormalized over the cohorts with an estimated dummy at that event time.
    The standard error of the average uses the joint coefficient covariance,
    projected to be positive semi-definite.

    A joint F test of the lead dummies (negative event times other than -1
    and *ref_event*) is stored in ``lead_joint_*``. It is best effort:
    ``lead_joint_f``/``lead_joint_p`` stay ``None`` if statsmodels cannot
    compute it.

    NOTE(review): treated-cohort rows whose event time falls outside *window*
    get no dummy. They are therefore pooled with the reference period rather
    than binned or dropped; confirm this is intended.

    Args:
        df: city-year panel with ``city_code6``, ``year``, *outcome*,
            ``<policy>_treat`` and ``<policy>_pilot_year``.
        policy: column prefix of the policy variables.
        outcome: dependent-variable column name.
        window: event times to estimate.
        ref_event: omitted reference event time.

    Returns:
        An ``EventStudyResult``. Its ``table`` always has the columns
        ``event_time, beta, se, ci_low, ci_high, n_cohorts``, even when empty.

    Raises:
        RuntimeError: if no treated cohort or no non-empty dummy exists.
    """
    d = df.copy()
    pilot_year = _pilot_year_series(d, policy)
    d["treat"] = _treat_series(d, policy)

    # Treated rows carry their pilot year as cohort. Everything else, including
    # treated rows whose pilot year is missing, ends up in cohort 0.
    cohort = np.where(d["treat"] == 1, pilot_year, 0)
    d["cohort"] = pd.Series(cohort, index=d.index).fillna(0).astype(int)
    d["event_time"] = np.where(d["treat"] == 1, d["year"] - d["cohort"], np.nan)

    cohorts = sorted(int(x) for x in d.loc[d["treat"] == 1, "cohort"].unique() if int(x) > 0)
    if not cohorts:
        raise RuntimeError(f"No treated cohorts found for policy={policy}")

    def dummy_name(g: int, e: int) -> str:
        # Single source of truth for the cohort x event dummy naming scheme.
        return f"D_g{g}_e{enc_event(e)}"

    # Suppose a dummy is non-zero only on rows with a missing outcome. Once
    # those rows are dropped from the estimation sample, it becomes an
    # all-zero column, and statsmodels' default pinv solve would report a
    # degenerate 0 estimate for it. So count support only on rows with an
    # observed outcome.
    if outcome in d.columns:
        has_outcome = d[outcome].notna()
    else:
        has_outcome = pd.Series(True, index=d.index)

    reg_cols: list[str] = []
    for g in cohorts:
        in_cohort = d["cohort"] == g
        for e in window:
            if e == ref_event:
                continue
            col = (in_cohort & (d["event_time"] == e)).astype(int)
            if int(col[has_outcome].sum()) == 0:
                continue
            name = dummy_name(g, e)
            d[name] = col
            reg_cols.append(name)

    if not reg_cols:
        raise RuntimeError(f"No non-empty event dummies for policy={policy}, outcome={outcome}")

    rhs = " + ".join(reg_cols) + " + C(city_code6) + C(year)"
    model = smf.ols(f"{outcome} ~ {rhs}", data=d)
    used_idx = model.data.row_labels  # rows actually used by the model
    res = model.fit(cov_type="cluster", cov_kwds={"groups": d.loc[used_idx, "city_code6"]})

    # Cohort weights by treated city count.
    cohort_city_counts = {g: int(d.loc[d["cohort"] == g, "city_code6"].nunique()) for g in cohorts}
    total = sum(cohort_city_counts.values())
    cohort_weights = {g: (cohort_city_counts[g] / total if total else 0.0) for g in cohorts}

    params = res.params
    vcov = res.cov_params()

    # Pre-trend (lead) joint test over negative event times except -1 and ref_event.
    lead_names: list[str] = []
    for e in window:
        if e >= 0 or e in (ref_event, -1):
            continue
        lead_names.extend(dummy_name(g, e) for g in cohorts if dummy_name(g, e) in params.index)

    lead_joint_f: float | None = None
    lead_joint_p: float | None = None
    lead_joint_k: int | None = None
    if lead_names:
        lead_joint_k = len(lead_names)
        try:
            test = res.f_test(", ".join(f"{n} = 0" for n in lead_names))
            # fvalue/pvalue may be 0-d or 1x1 arrays depending on the
            # statsmodels version; squeeze before converting to float.
            lead_joint_f = float(np.squeeze(getattr(test, "fvalue", np.nan)))
            lead_joint_p = float(np.squeeze(getattr(test, "pvalue", np.nan)))
        except Exception:
            # Deliberate best effort: keep k, report the statistic as unavailable.
            lead_joint_f = None
            lead_joint_p = None

    rows = []
    for e in window:
        if e == ref_event:
            continue
        # Track (cohort, dummy) pairs directly instead of re-parsing dummy names.
        active = [(g, dummy_name(g, e)) for g in cohorts if dummy_name(g, e) in params.index]
        if not active:
            continue
        active_cohorts = [g for g, _ in active]
        names = [n for _, n in active]

        # Only weight cohorts that actually appear at this event time.
        w = np.array([cohort_weights[g] for g in active_cohorts], dtype=float)
        if w.sum() > 0:
            w = w / w.sum()

        beta = float(w @ params[names].to_numpy())
        V = vcov.loc[names, names].to_numpy()
        V = (V + V.T) / 2.0
        try:
            # Project onto the PSD cone so the combined variance is non-negative.
            vals, vecs = np.linalg.eigh(V)
            V = (vecs * np.clip(vals, 0.0, None)) @ vecs.T
        except np.linalg.LinAlgError:
            pass  # keep the symmetrized matrix; the variance is clipped at 0 below
        var = float(w @ V @ w)
        se = math.sqrt(max(var, 0.0))
        rows.append(
            {
                "event_time": int(e),
                "beta": beta,
                "se": se,
                "ci_low": beta - 1.96 * se,
                "ci_high": beta + 1.96 * se,
                "n_cohorts": len(active_cohorts),
            }
        )

    # Explicit columns keep the schema even when no event time was estimable.
    table = (
        pd.DataFrame(rows, columns=["event_time", "beta", "se", "ci_low", "ci_high", "n_cohorts"])
        .sort_values("event_time")
        .reset_index(drop=True)
    )
    return EventStudyResult(
        policy=policy,
        outcome=outcome,
        sample_note="Sun-Abraham cohort×event study with city/year FE; SE clustered by city",
        window=window,
        ref_event=ref_event,
        cohorts=cohorts,
        cohort_weights=cohort_weights,
        table=table,
        lead_joint_f=lead_joint_f,
        lead_joint_p=lead_joint_p,
        lead_joint_k=lead_joint_k,
    )


def write_event_study_plot(result: EventStudyResult, out_path: Path, *, title: str):
    """Save a coefficient plot (estimates with +/-1.96*SE bars) for *result*.

    Nothing is written when the result table is empty. A vertical guide at
    x = -0.5 separates the leads from event time 0. matplotlib is imported
    inside the function, so it is only needed when plotting.
    """
    import matplotlib.pyplot as plt

    coef_table = result.table
    if coef_table.empty:
        return

    event_times = coef_table["event_time"].to_numpy()
    estimates = coef_table["beta"].to_numpy()
    half_width = 1.96 * coef_table["se"].to_numpy()
    marker_color = "#1f77b4"

    fig, ax = plt.subplots(figsize=(7.5, 4.2))
    ax.axhline(0, color="black", linewidth=1, alpha=0.6)
    ax.axvline(-0.5, color="gray", linewidth=1, alpha=0.4)
    ax.errorbar(
        event_times,
        estimates,
        yerr=half_width,
        fmt="o",
        color=marker_color,
        ecolor=marker_color,
        elinewidth=1.2,
        capsize=3,
    )
    ax.set_xlabel("Event time (years, 0 = first treated year)")
    ax.set_ylabel(result.outcome)
    ax.set_title(title)
    ax.grid(True, axis="y", alpha=0.25)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)


def prepare_panel(panel_path: Path) -> pd.DataFrame:
    """Load the city-year panel CSV and derive the analysis variables.

    Steps:
      * normalize ``city_code6`` to 6-digit strings and coerce ``year`` to int,
        dropping rows where either is missing;
      * coerce the known numeric columns (when present) with
        ``errors="coerce"``;
      * drop rows without ``co2_tons`` and add ``log_co2``;
      * add ``good_day_share = good_days / days`` (days == 0 -> NaN) when
        both columns exist;
      * add ``log_co2_per_gdp`` from ``gdp_wanyuan`` when the file does not
        already provide it.

    Non-positive CO2 or GDP values produce NaN logs rather than ``-inf``, so
    regressions drop them like any other missing value.
    """
    df = pd.read_csv(panel_path, encoding="utf-8-sig")
    # Keep the nullable-string codes until after dropna. An early
    # astype(str) would turn <NA> into the literal "<NA>", which dropna then
    # keeps as one pooled pseudo-city.
    df["city_code6"] = normalize_code6(df["city_code6"])
    df["year"] = pd.to_numeric(df["year"], errors="coerce").astype("Int64")
    df = df.dropna(subset=["city_code6", "year"]).copy()
    df["city_code6"] = df["city_code6"].astype(str)
    df["year"] = df["year"].astype(int)

    numeric_cols = (
        "co2_tons", "aqi_mean", "aqi_n", "days", "good_days",
        "docs_total", "docs_hit", "wechat_rate_10k", "log_wechat_hits",
        "log_wechat_total", "topic_intensity_year",
    )
    for c in numeric_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    df = df.dropna(subset=["co2_tons"]).copy()
    co2 = df["co2_tons"].astype(float)
    # np.log(0) is -inf, and patsy treats only NaN/None as missing, so a -inf
    # would silently enter the OLS. Mask non-positive values to NaN first.
    df["log_co2"] = np.log(co2.where(co2 > 0))
    if "good_days" in df.columns and "days" in df.columns:
        df["good_day_share"] = df["good_days"] / df["days"].replace({0: np.nan})
    if "log_co2_per_gdp" not in df.columns and "gdp_wanyuan" in df.columns:
        gdp = pd.to_numeric(df["gdp_wanyuan"], errors="coerce")
        df["log_co2_per_gdp"] = df["log_co2"] - np.log(gdp.where(gdp > 0))
    return df


def write_markdown_report(
    out_path: Path,
    *,
    panel_path: Path,
    summaries: list[tuple[str, EventStudyResult, Path | None]],
    regression_tables: list[tuple[str, pd.DataFrame]] | None = None,
):
    """Render regression tables and event-study summaries into one markdown file.

    The output contains a fixed header and spec section, then each
    ``(title, table)`` in *regression_tables*, then one section per
    ``(title, result, figure_path)`` in *summaries*. Each summary section holds
    the policy metadata, the optional lead joint test, the optional figure
    path, and the event-time coefficient table formatted to 4 decimals.

    Non-empty tables are rendered with ``DataFrame.to_markdown``, which needs
    the optional ``tabulate`` package. The parent directory of *out_path* is
    created if needed.
    """

    def fmt4(x) -> str:
        return f"{x:.4f}"

    out_path.parent.mkdir(parents=True, exist_ok=True)
    parts: list[str] = [
        "# Environment Policy Results (CO2 / AQI)\n",
        f"- Panel: `{panel_path}`\n",
        "- Raw archives are treated as read-only and are not redistributed; all outputs here are derived files.\n",
        "\n## Key Specs\n",
        "- Fixed effects: city + year\n",
        "- Standard errors: clustered by city\n",
        "- Event study: Sun-Abraham cohort×event interactions (reference event = -1)\n",
        "- Notes: Results are descriptive TWFE/event-study estimates; mechanism regressions do not imply causal mediation.\n",
    ]

    for table_title, reg_table in regression_tables or []:
        parts += [f"\n## {table_title}\n\n", reg_table.to_markdown(index=False), "\n"]

    for section_title, res, figure in summaries:
        parts += [
            f"\n## {section_title}\n",
            f"- Policy: `{res.policy}`\n",
            f"- Outcome: `{res.outcome}`\n",
            f"- Cohorts: `{res.cohorts}`\n",
            f"- Window: `{res.window}` (ref `{res.ref_event}`)\n",
        ]
        if res.lead_joint_k:
            p_txt = "NA" if res.lead_joint_p is None else f"{res.lead_joint_p:.4f}"
            f_txt = "NA" if res.lead_joint_f is None else f"{res.lead_joint_f:.2f}"
            parts.append(f"- Lead joint test (event_time<=-2): `p={p_txt}`, `F={f_txt}`, `k={res.lead_joint_k}`\n")
        if figure is not None:
            parts.append(f"- Figure: `{figure}`\n")
        parts.append("\nEvent-time coefficients (weighted across cohorts):\n\n")
        if res.table.empty:
            parts.append("_No estimable coefficients in this window._\n")
            continue
        shown = res.table.copy()
        for col in ("beta", "se", "ci_low", "ci_high"):
            shown[col] = shown[col].map(fmt4)
        parts += [
            shown[["event_time", "beta", "se", "ci_low", "ci_high", "n_cohorts"]].to_markdown(index=False),
            "\n",
        ]

    out_path.write_text("".join(parts), encoding="utf-8")


def twfe_cluster(
    df: pd.DataFrame,
    *,
    formula: str,
    coef: str,
) -> tuple[float, float, int]:
    """Fit an OLS *formula* with city-clustered SEs and report one coefficient.

    The cluster groups are the ``city_code6`` values of the rows the model
    actually uses (``model.data.row_labels``), so they line up with the
    estimation sample.

    Returns:
        ``(estimate, standard error, number of observations)`` for *coef*.
    """
    ols_model = smf.ols(formula, data=df)
    clusters = df.loc[ols_model.data.row_labels, "city_code6"]
    fit = ols_model.fit(cov_type="cluster", cov_kwds={"groups": clusters})
    estimate, stderr = fit.params[coef], fit.bse[coef]
    return float(estimate), float(stderr), int(fit.nobs)


def twfe_interaction_only(
    df: pd.DataFrame,
    *,
    y: str,
    did_col: str,
    exposure_col: str,
) -> tuple[float, float, int]:
    """Regress *y* on ``did_col:exposure_col`` with city/year FE and city-clustered SEs.

    Rows missing *y*, *did_col* or *exposure_col* are dropped. *did_col* is
    coerced to int (unparseable -> 0) and *exposure_col* to numeric, and rows
    whose exposure fails to parse are dropped.

    NOTE(review): only the interaction enters the formula; there is no
    standalone *did_col* term. If the average DID effect is non-zero and
    correlated with the interaction, it loads onto this coefficient. Confirm
    the omission is intended.

    Returns:
        ``(estimate, clustered SE, nobs)`` for the interaction term.
    """
    sample = df.dropna(subset=[y, did_col, exposure_col]).copy()
    sample[did_col] = pd.to_numeric(sample[did_col], errors="coerce").fillna(0).astype(int)
    sample[exposure_col] = pd.to_numeric(sample[exposure_col], errors="coerce")
    sample = sample.dropna(subset=[exposure_col]).copy()

    term = f"{did_col}:{exposure_col}"
    ols_model = smf.ols(f"{y} ~ {term} + C(city_code6) + C(year)", data=sample)
    clusters = sample.loc[ols_model.data.row_labels, "city_code6"]
    fit = ols_model.fit(cov_type="cluster", cov_kwds={"groups": clusters})
    return float(fit.params[term]), float(fit.bse[term]), int(fit.nobs)


def twfe_with_city_trends(
    df: pd.DataFrame,
    *,
    formula_rhs: str,
    coef: str,
) -> tuple[float, float, int]:
    """Re-fit a TWFE formula after adding city-specific linear time trends.

    Despite its name, *formula_rhs* must be the full formula including the
    left-hand side, e.g. ``"y ~ did + C(city_code6) + C(year)"``. The term
    ``+ C(city_code6):year_c`` is appended to it, where ``year_c`` is the year
    minus the earliest year in *df*.

    Returns:
        ``(estimate, HC1 standard error, nobs)`` for *coef*.
    """
    d = df.copy()
    d["year_c"] = d["year"] - int(d["year"].min())
    model = smf.ols(f"{formula_rhs} + C(city_code6):year_c", data=d)
    # Cluster-robust SE can become numerically unstable with high-dimensional FE+trends in statsmodels;
    # use HC1 here as a robustness check on the coefficient magnitude/sign.
    res = model.fit(cov_type="HC1")
    return float(res.params[coef]), float(res.bse[coef]), int(res.nobs)


def drop_municipalities(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* without the four province-level municipalities.

    Excluded ``city_code6`` values: 110000 (Beijing), 120000 (Tianjin),
    310000 (Shanghai), 500000 (Chongqing). The original index is kept.
    """
    municipality_codes = ["110000", "120000", "310000", "500000"]
    keep = ~df["city_code6"].isin(municipality_codes)
    return df.loc[keep].copy()


def add_baseline_secondary_share(df: pd.DataFrame, start_year: int = 2007, end_year: int = 2009) -> pd.DataFrame:
    """Attach each city's mean ``share_secondary_pct`` over a baseline window.

    Args:
        df: panel with ``city_code6``, ``year`` and optionally
            ``share_secondary_pct``.
        start_year, end_year: inclusive baseline year range.

    Returns:
        A copy of *df*. If ``share_secondary_pct`` is absent, the copy is
        returned unchanged. Otherwise ``share_secondary_pct`` is coerced to
        numeric and ``base_secondary_share`` is added; it is NaN for cities
        with no baseline observation.

    The baseline is mapped onto ``city_code6`` instead of merged. This keeps
    the caller's index, so the result still aligns with *df*. It also
    overwrites an existing ``base_secondary_share`` column rather than
    creating ``_x``/``_y`` duplicates when the helper is applied twice.
    """
    if "share_secondary_pct" not in df.columns:
        return df.copy()
    d = df.copy()
    d["share_secondary_pct"] = pd.to_numeric(d["share_secondary_pct"], errors="coerce")
    in_baseline = d["year"].between(start_year, end_year)
    base = d.loc[in_baseline].groupby("city_code6")["share_secondary_pct"].mean()
    d["base_secondary_share"] = d["city_code6"].map(base)
    return d


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--panel", default="media_project/out/env_policy_city_year_with_wechat.csv")
    ap.add_argument("--report", default="media_project/reports/env_policy_results.md")
    ap.add_argument("--fig-dir", default="media_project/reports/assets")
    args = ap.parse_args()

    panel_path = Path(args.panel)
    if not panel_path.exists():
        raise SystemExit(
            f"Missing panel: {panel_path}. Run build_env_policy_panel.py, then build_gov_wechat_topic_index.py + aggregate + build_env_media_panel.py."
        )

    df = prepare_panel(panel_path)

    fig_dir = Path(args.fig_dir)
    report_path = Path(args.report)

    summaries: list[tuple[str, EventStudyResult, Path | None]] = []
    regression_tables: list[tuple[str, pd.DataFrame]] = []

    # CO2: main (lowcarbon), secondary (carbon trading).
    co2_window = list(range(-8, 7))
    low_co2 = sun_abraham_event_study(df, policy="lowcarbon", outcome="log_co2", window=co2_window)
    low_fig = fig_dir / "event_lowcarbon_log_co2.png"
    write_event_study_plot(low_co2, low_fig, title="Low-carbon pilot → log(CO2) (city×year)")
    summaries.append(("CO2 (Main): Low-carbon pilot", low_co2, low_fig))

    if "log_co2_per_gdp" in df.columns and df["log_co2_per_gdp"].notna().any():
        low_int = sun_abraham_event_study(df.dropna(subset=["log_co2_per_gdp"]), policy="lowcarbon", outcome="log_co2_per_gdp", window=co2_window)
        low_int_fig = fig_dir / "event_lowcarbon_log_co2_per_gdp.png"
        write_event_study_plot(low_int, low_int_fig, title="Low-carbon pilot → log(CO2/GDP) (city×year)")
        summaries.append(("CO2 Intensity (Main): Low-carbon pilot", low_int, low_int_fig))

    ct_co2 = sun_abraham_event_study(df, policy="carbon_trading", outcome="log_co2", window=list(range(-6, 7)))
    ct_fig = fig_dir / "event_carbon_trading_log_co2.png"
    write_event_study_plot(ct_co2, ct_fig, title="Carbon-trading pilot → log(CO2) (city×year)")
    summaries.append(("CO2 (Robustness): Carbon-trading pilot", ct_co2, ct_fig))

    if "log_co2_per_gdp" in df.columns and df["log_co2_per_gdp"].notna().any():
        ct_int = sun_abraham_event_study(
            df.dropna(subset=["log_co2_per_gdp"]), policy="carbon_trading", outcome="log_co2_per_gdp", window=list(range(-6, 7))
        )
        ct_int_fig = fig_dir / "event_carbon_trading_log_co2_per_gdp.png"
        write_event_study_plot(ct_int, ct_int_fig, title="Carbon-trading pilot → log(CO2/GDP) (city×year)")
        summaries.append(("CO2 Intensity (Robustness): Carbon-trading pilot", ct_int, ct_int_fig))

    # Air quality (consistent across years): share of "good" days (优/良) at city-year level.
    if "good_day_share" in df.columns and df["good_day_share"].notna().any():
        df_gd = df.dropna(subset=["good_day_share", "days"]).copy()
        df_gd["days"] = pd.to_numeric(df_gd["days"], errors="coerce")
        df_gd = df_gd.dropna(subset=["days"]).copy()
        df_gd = df_gd[df_gd["days"] >= 330].copy()

        aq_rows = []
        # TWFE DID
        for pol in ["lowcarbon", "carbon_trading"]:
            treat = _treat_series(df, pol)
            py = _pilot_year_series(df, pol)
            post = ((df["year"] >= py) & (treat == 1)).astype(int)
            dtmp = df_gd.copy()
            dtmp[f"{pol}_did_tmp"] = (treat * post).astype(int)
            coef, se, n = twfe_cluster(
                dtmp.dropna(subset=["good_day_share"]),
                formula=f"good_day_share ~ {pol}_did_tmp + C(city_code6) + C(year)",
                coef=f"{pol}_did_tmp",
            )
            aq_rows.append({"spec": f"Good-day share ~ {pol} DID", "coef": coef, "se": se, "n": n})

        aq_tbl = pd.DataFrame(aq_rows)
        aq_tbl["coef"] = aq_tbl["coef"].map(lambda x: f"{x:.4f}")
        aq_tbl["se"] = aq_tbl["se"].map(lambda x: f"{x:.4f}")
        regression_tables.append(("Air Quality (Good-day Share)", aq_tbl))

        # Event studies
        lc_g = sun_abraham_event_study(df_gd, policy="lowcarbon", outcome="good_day_share", window=list(range(-8, 7)))
        lc_g_fig = fig_dir / "event_lowcarbon_good_day_share.png"
        write_event_study_plot(lc_g, lc_g_fig, title="Low-carbon pilot → Good-day share (city×year)")
        summaries.append(("Air Quality (Main): Low-carbon pilot", lc_g, lc_g_fig))

        ct_g = sun_abraham_event_study(df_gd, policy="carbon_trading", outcome="good_day_share", window=list(range(-6, 7)))
        ct_g_fig = fig_dir / "event_carbon_trading_good_day_share.png"
        write_event_study_plot(ct_g, ct_g_fig, title="Carbon-trading pilot → Good-day share (city×year)")
        summaries.append(("Air Quality (Robustness): Carbon-trading pilot", ct_g, ct_g_fig))

    # AQI: only years with valid AQI.
    df_aqi = df.dropna(subset=["aqi_mean"]).copy()
    df_aqi = df_aqi[df_aqi["year"] >= 2014]
    # Carbon trading: cohorts 2013/2014/2016, limited leads due to sample start.
    ct_aqi = sun_abraham_event_study(df_aqi, policy="carbon_trading", outcome="aqi_mean", window=list(range(-2, 8)))
    ct_aqi_fig = fig_dir / "event_carbon_trading_aqi_mean.png"
    write_event_study_plot(ct_aqi, ct_aqi_fig, title="Carbon-trading pilot → AQI mean (city×year, 2014+)")
    summaries.append(("AQI (Secondary): Carbon-trading pilot", ct_aqi, ct_aqi_fig))

    # Low-carbon AQI: focus on 2017 cohort vs never-treated to avoid 'already-treated' cohorts in 2014+ sample.
    # Implement by zeroing out treat for cohorts 2010/2012 in this subsample.
    df_lc_2017 = df_aqi.copy()
    pilot = _pilot_year_series(df_lc_2017, "lowcarbon")
    treat = _treat_series(df_lc_2017, "lowcarbon")
    cohort = np.where(treat == 1, pilot, 0)
    cohort = pd.Series(cohort, index=df_lc_2017.index).fillna(0).astype(int)
    df_lc_2017["lowcarbon_treat"] = ((cohort == 2017).astype(int)).astype("Int64")
    df_lc_2017["lowcarbon_pilot_year"] = np.where(cohort == 2017, 2017, np.nan)
    lc_aqi = sun_abraham_event_study(df_lc_2017, policy="lowcarbon", outcome="aqi_mean", window=list(range(-3, 7)))
    lc_aqi_fig = fig_dir / "event_lowcarbon2017_aqi_mean.png"
    write_event_study_plot(lc_aqi, lc_aqi_fig, title="Low-carbon 2017 cohort → AQI mean (city×year, 2014+)")
    summaries.append(("AQI (Secondary): Low-carbon pilot (2017 cohort only)", lc_aqi, lc_aqi_fig))

    # Heterogeneity: baseline industrial share (pre-treatment) × low-carbon DID.
    if "share_secondary_pct" in df.columns and "lowcarbon_treat" in df.columns:
        dhet = df.copy()
        dhet["share_secondary_pct"] = pd.to_numeric(dhet["share_secondary_pct"], errors="coerce")
        # Pre-window chosen to precede first low-carbon cohort (2010).
        base = (
            dhet[dhet["year"].between(2007, 2009)]
            .groupby("city_code6")["share_secondary_pct"]
            .mean()
            .rename("base_secondary_share")
        )
        dhet = dhet.merge(base, on="city_code6", how="left")
        dhet = dhet.dropna(subset=["base_secondary_share"]).copy()
        base_std = float(dhet["base_secondary_share"].std(ddof=0))
        if base_std > 0:
            dhet["base_secondary_share_z"] = (dhet["base_secondary_share"] - float(dhet["base_secondary_share"].mean())) / base_std

            # Build lowcarbon DID from pilot year (works even if *_did is absent/floaty).
            treat = _treat_series(dhet, "lowcarbon")
            py = _pilot_year_series(dhet, "lowcarbon")
            post = ((dhet["year"] >= py) & (treat == 1)).astype(int)
            dhet["lowcarbon_did_tmp"] = (treat * post).astype(int)

            rows_h = []
            b, se, n = twfe_interaction_only(
                dhet.dropna(subset=["log_co2"]),
                y="log_co2",
                did_col="lowcarbon_did_tmp",
                exposure_col="base_secondary_share_z",
            )
            rows_h.append({"spec": "log(CO2) ~ lowcarbon DID × baseline secondary share (z)", "coef": b, "se": se, "n": n})

            if "log_co2_per_gdp" in dhet.columns and dhet["log_co2_per_gdp"].notna().any():
                b, se, n = twfe_interaction_only(
                    dhet.dropna(subset=["log_co2_per_gdp"]),
                    y="log_co2_per_gdp",
                    did_col="lowcarbon_did_tmp",
                    exposure_col="base_secondary_share_z",
                )
                rows_h.append(
                    {"spec": "log(CO2/GDP) ~ lowcarbon DID × baseline secondary share (z)", "coef": b, "se": se, "n": n}
                )

            het = pd.DataFrame(rows_h)
            het["coef"] = het["coef"].map(lambda x: f"{x:.4f}")
            het["se"] = het["se"].map(lambda x: f"{x:.4f}")
            regression_tables.append(("Heterogeneity (Industrial Share)", het))

            # Grouped event studies (high vs low baseline secondary share) on good-day share.
            if "good_day_share" in dhet.columns and dhet["good_day_share"].notna().any():
                dhet = dhet.dropna(subset=["good_day_share", "days"]).copy()
                dhet["days"] = pd.to_numeric(dhet["days"], errors="coerce")
                dhet = dhet.dropna(subset=["days"]).copy()
                dhet = dhet[dhet["days"] >= 330].copy()
                med = float(dhet["base_secondary_share"].median())
                dhet["high_secondary"] = (dhet["base_secondary_share"] >= med).astype(int)

                # Grouped TWFE DID on good-day share (easy-to-read contrast).
                rows_gdid = []
                for label, mask in [("lowsec", dhet["high_secondary"] == 0), ("highsec", dhet["high_secondary"] == 1)]:
                    dg = dhet[mask].copy()
                    treat = _treat_series(dg, "lowcarbon")
                    py = _pilot_year_series(dg, "lowcarbon")
                    post = ((dg["year"] >= py) & (treat == 1)).astype(int)
                    dg["lowcarbon_did_tmp"] = (treat * post).astype(int)
                    coef, se, n = twfe_cluster(
                        dg.dropna(subset=["good_day_share"]),
                        formula="good_day_share ~ lowcarbon_did_tmp + C(city_code6) + C(year)",
                        coef="lowcarbon_did_tmp",
                    )
                    rows_gdid.append({"spec": f"Good-day share ~ lowcarbon DID ({label})", "coef": coef, "se": se, "n": n})
                if rows_gdid:
                    gt = pd.DataFrame(rows_gdid)
                    gt["coef"] = gt["coef"].map(lambda x: f"{x:.4f}")
                    gt["se"] = gt["se"].map(lambda x: f"{x:.4f}")
                    regression_tables.append(("Air Quality (Grouped DID by Industrial Share)", gt))

                for label, mask in [("lowsec", dhet["high_secondary"] == 0), ("highsec", dhet["high_secondary"] == 1)]:
                    dg = dhet[mask].copy()
                    # Needs original policy vars; use lowcarbon_treat/pilot_year already in panel
                    try:
                        es = sun_abraham_event_study(dg, policy="lowcarbon", outcome="good_day_share", window=list(range(-8, 7)))
                    except Exception:
                        continue
                    fig = fig_dir / f"event_lowcarbon_good_day_share_{label}.png"
                    write_event_study_plot(
                        es, fig, title=f"Low-carbon pilot → Good-day share ({label}; baseline secondary share split)"
                    )
                    summaries.append((f"Air Quality (Heterogeneity): Low-carbon pilot ({label})", es, fig))

    # Mechanism: gov-wechat low-carbon agenda index (yearly, 2013+ coverage).
    if "wechat_rate_10k" in df.columns and df["wechat_rate_10k"].notna().any():
        df_w = df.dropna(subset=["wechat_rate_10k"]).copy()
        # Main first-stage: low-carbon 2017 cohort (wechat coverage starts 2013).
        df_w_lc = df_w.copy()
        pilot = _pilot_year_series(df_w_lc, "lowcarbon")
        treat = _treat_series(df_w_lc, "lowcarbon")
        cohort = np.where(treat == 1, pilot, 0)
        cohort = pd.Series(cohort, index=df_w_lc.index).fillna(0).astype(int)
        df_w_lc["lowcarbon_treat"] = ((cohort == 2017).astype(int)).astype("Int64")
        df_w_lc["lowcarbon_pilot_year"] = np.where(cohort == 2017, 2017, np.nan)
        lc_w = sun_abraham_event_study(df_w_lc, policy="lowcarbon", outcome="wechat_rate_10k", window=list(range(-3, 7)))
        lc_w_fig = fig_dir / "event_lowcarbon2017_wechat_rate_10k.png"
        write_event_study_plot(lc_w, lc_w_fig, title="Low-carbon 2017 cohort → WeChat low-carbon rate (per 10k posts)")
        summaries.append(("Mechanism: Gov WeChat (2017 low-carbon cohort)", lc_w, lc_w_fig))

        # Carbon trading: cohorts 2013/2014/2016, limited pre.
        ct_w = sun_abraham_event_study(df_w, policy="carbon_trading", outcome="wechat_rate_10k", window=list(range(-2, 8)))
        ct_w_fig = fig_dir / "event_carbon_trading_wechat_rate_10k.png"
        write_event_study_plot(ct_w, ct_w_fig, title="Carbon-trading pilot → WeChat low-carbon rate (per 10k posts)")
        summaries.append(("Mechanism: Gov WeChat (carbon-trading pilot)", ct_w, ct_w_fig))

    # Mechanism regressions: policy -> wechat -> outcomes (same sample, TWFE).
    if "wechat_rate_10k" in df.columns and df["wechat_rate_10k"].notna().any():
        df_we = df.dropna(subset=["wechat_rate_10k"]).copy()
        df_we = df_we[(df_we["year"] >= 2013) & (df_we["year"] <= 2023)].copy()
        df_we = df_we.sort_values(["city_code6", "year"], kind="mergesort")
        df_we["wechat_rate_10k_lag1"] = df_we.groupby("city_code6")["wechat_rate_10k"].shift(1)
        wechat_std = float(df_we["wechat_rate_10k"].std(ddof=0))
        if wechat_std > 0:
            df_we["wechat_z"] = (df_we["wechat_rate_10k"] - float(df_we["wechat_rate_10k"].mean())) / wechat_std
            df_we["wechat_z_lag1"] = df_we.groupby("city_code6")["wechat_z"].shift(1)

        # Build DID terms (derived from pilot year).
        for pol in ["lowcarbon", "carbon_trading"]:
            treat = _treat_series(df_we, pol)
            py = _pilot_year_series(df_we, pol)
            post = ((df_we["year"] >= py) & (treat == 1)).astype(int)
            df_we[f"{pol}_did"] = (treat * post).astype(int)

        rows = []
        b, se, n = twfe_cluster(
            df_we, formula="wechat_rate_10k ~ lowcarbon_did + C(city_code6) + C(year)", coef="lowcarbon_did"
        )
        rows.append({"spec": "WeChat rate (per 10k) ~ lowcarbon DID", "coef": b, "se": se, "n": n})

        b, se, n = twfe_cluster(
            df_we,
            formula="wechat_rate_10k ~ carbon_trading_did + C(city_code6) + C(year)",
            coef="carbon_trading_did",
        )
        rows.append({"spec": "WeChat rate (per 10k) ~ carbon-trading DID", "coef": b, "se": se, "n": n})

        # CO2 on same sample
        b, se, n = twfe_cluster(df_we, formula="log_co2 ~ lowcarbon_did + C(city_code6) + C(year)", coef="lowcarbon_did")
        rows.append({"spec": "log(CO2) ~ lowcarbon DID (wechat sample)", "coef": b, "se": se, "n": n})

        if "wechat_z" in df_we.columns:
            b, se, n = twfe_cluster(
                df_we,
                formula="log_co2 ~ lowcarbon_did + wechat_z + C(city_code6) + C(year)",
                coef="wechat_z",
            )
            rows.append({"spec": "log(CO2) ~ WeChat (z) + lowcarbon DID", "coef": b, "se": se, "n": n})
            df_we_l1 = df_we.dropna(subset=["wechat_z_lag1"]).copy()
            b, se, n = twfe_cluster(
                df_we_l1,
                formula="log_co2 ~ lowcarbon_did + wechat_z_lag1 + C(city_code6) + C(year)",
                coef="wechat_z_lag1",
            )
            rows.append({"spec": "log(CO2) ~ WeChat (z,t-1) + lowcarbon DID", "coef": b, "se": se, "n": n})

        if "log_co2_per_gdp" in df_we.columns and df_we["log_co2_per_gdp"].notna().any():
            df_int = df_we.dropna(subset=["log_co2_per_gdp"]).copy()
            b, se, n = twfe_cluster(
                df_int, formula="log_co2_per_gdp ~ lowcarbon_did + C(city_code6) + C(year)", coef="lowcarbon_did"
            )
            rows.append({"spec": "log(CO2/GDP) ~ lowcarbon DID (wechat sample)", "coef": b, "se": se, "n": n})
            if "wechat_z" in df_int.columns:
                b, se, n = twfe_cluster(
                    df_int,
                    formula="log_co2_per_gdp ~ lowcarbon_did + wechat_z + C(city_code6) + C(year)",
                    coef="wechat_z",
                )
                rows.append({"spec": "log(CO2/GDP) ~ WeChat (z) + lowcarbon DID", "coef": b, "se": se, "n": n})
                df_int_l1 = df_int.dropna(subset=["wechat_z_lag1"]).copy()
                b, se, n = twfe_cluster(
                    df_int_l1,
                    formula="log_co2_per_gdp ~ lowcarbon_did + wechat_z_lag1 + C(city_code6) + C(year)",
                    coef="wechat_z_lag1",
                )
                rows.append({"spec": "log(CO2/GDP) ~ WeChat (z,t-1) + lowcarbon DID", "coef": b, "se": se, "n": n})

        reg = pd.DataFrame(rows)
        reg["coef"] = reg["coef"].map(lambda x: f"{x:.4f}")
        reg["se"] = reg["se"].map(lambda x: f"{x:.4f}")
        regression_tables.append(("Mechanism Regressions (TWFE, clustered SE)", reg))

        # AQI association (2014+ sample where AQI exists)
        if "aqi_mean" in df.columns and df["aqi_mean"].notna().any():
            df_a = df.dropna(subset=["aqi_mean", "wechat_rate_10k"]).copy()
            df_a = df_a[df_a["year"] >= 2014].copy()
            df_a = df_a.sort_values(["city_code6", "year"], kind="mergesort")
            df_a["wechat_rate_10k_lag1"] = df_a.groupby("city_code6")["wechat_rate_10k"].shift(1)
            a_std = float(df_a["wechat_rate_10k"].std(ddof=0))
            if a_std > 0:
                df_a["wechat_z"] = (df_a["wechat_rate_10k"] - float(df_a["wechat_rate_10k"].mean())) / a_std
                df_a["wechat_z_lag1"] = df_a.groupby("city_code6")["wechat_z"].shift(1)
            for pol in ["lowcarbon", "carbon_trading"]:
                treat = _treat_series(df_a, pol)
                py = _pilot_year_series(df_a, pol)
                post = ((df_a["year"] >= py) & (treat == 1)).astype(int)
                df_a[f"{pol}_did"] = (treat * post).astype(int)

            rows2 = []
            b, se, n = twfe_cluster(
                df_a, formula="aqi_mean ~ carbon_trading_did + C(city_code6) + C(year)", coef="carbon_trading_did"
            )
            rows2.append({"spec": "AQI mean ~ carbon-trading DID (wechat sample)", "coef": b, "se": se, "n": n})
            if "wechat_z" in df_a.columns:
                b, se, n = twfe_cluster(
                    df_a,
                    formula="aqi_mean ~ carbon_trading_did + wechat_z + C(city_code6) + C(year)",
                    coef="wechat_z",
                )
                rows2.append({"spec": "AQI mean ~ WeChat (z) + carbon-trading DID", "coef": b, "se": se, "n": n})

                df_a_l1 = df_a.dropna(subset=["wechat_z_lag1"]).copy()
                b, se, n = twfe_cluster(
                    df_a_l1,
                    formula="aqi_mean ~ carbon_trading_did + wechat_z_lag1 + C(city_code6) + C(year)",
                    coef="wechat_z_lag1",
                )
                rows2.append({"spec": "AQI mean ~ WeChat (z,t-1) + carbon-trading DID", "coef": b, "se": se, "n": n})
            reg2 = pd.DataFrame(rows2)
            reg2["coef"] = reg2["coef"].map(lambda x: f"{x:.4f}")
            reg2["se"] = reg2["se"].map(lambda x: f"{x:.4f}")
            regression_tables.append(("AQI Regressions (TWFE, clustered SE)", reg2))

        # Tie mechanism to the main air-quality outcome (good-day share).
        if "good_day_share" in df.columns and df["good_day_share"].notna().any():
            # Carbon-trading treated cities have very limited pre-coverage in the wechat-observed subsample;
            # use low-carbon 2017 cohort for a cleaner (pre/post) time-ordering check.
            dgd = df_we.dropna(subset=["good_day_share"]).copy()
            pilot = _pilot_year_series(dgd, "lowcarbon")
            treat = _treat_series(dgd, "lowcarbon")
            cohort = np.where(treat == 1, pilot, 0)
            cohort = pd.Series(cohort, index=dgd.index).fillna(0).astype(int)
            dgd["lc2017_treat"] = (cohort == 2017).astype(int)
            dgd["lc2017_post"] = ((dgd["year"] >= 2017) & (dgd["lc2017_treat"] == 1)).astype(int)
            dgd["lc2017_did"] = dgd["lc2017_treat"] * dgd["lc2017_post"]

            rows3 = []
            b, se, n = twfe_cluster(
                dgd,
                formula="good_day_share ~ lc2017_did + C(city_code6) + C(year)",
                coef="lc2017_did",
            )
            rows3.append({"spec": "Good-day share ~ lowcarbon DID (2017 cohort; wechat sample)", "coef": b, "se": se, "n": n})

            if "wechat_z" in dgd.columns:
                b, se, n = twfe_cluster(
                    dgd,
                    formula="good_day_share ~ lc2017_did + wechat_z + C(city_code6) + C(year)",
                    coef="wechat_z",
                )
                rows3.append({"spec": "Good-day share ~ WeChat (z) + lowcarbon DID (2017 cohort)", "coef": b, "se": se, "n": n})

                dgd_l1 = dgd.dropna(subset=["wechat_z_lag1"]).copy()
                b, se, n = twfe_cluster(
                    dgd_l1,
                    formula="good_day_share ~ lc2017_did + wechat_z_lag1 + C(city_code6) + C(year)",
                    coef="wechat_z_lag1",
                )
                rows3.append({"spec": "Good-day share ~ WeChat (z,t-1) + lowcarbon DID (2017 cohort)", "coef": b, "se": se, "n": n})

            reg3 = pd.DataFrame(rows3)
            reg3["coef"] = reg3["coef"].map(lambda x: f"{x:.4f}")
            reg3["se"] = reg3["se"].map(lambda x: f"{x:.4f}")
            regression_tables.append(("Good-day Share (Wechat Mechanism)", reg3))

    # Robustness: city linear trends + excluding municipalities (focus on main outcomes).
    rob_rows = []
    # Build DID terms on full df (works for any panel variant).
    for pol in ["lowcarbon", "carbon_trading"]:
        treat = _treat_series(df, pol)
        py = _pilot_year_series(df, pol)
        post = ((df["year"] >= py) & (treat == 1)).astype(int)
        df[f"{pol}_did_tmp"] = (treat * post).astype(int)

    # CO2 intensity
    if "log_co2_per_gdp" in df.columns and df["log_co2_per_gdp"].notna().any():
        d1 = df.dropna(subset=["log_co2_per_gdp"]).copy()
        b, se, n = twfe_with_city_trends(
            d1, formula_rhs="log_co2_per_gdp ~ lowcarbon_did_tmp + C(city_code6) + C(year)", coef="lowcarbon_did_tmp"
        )
        rob_rows.append({"spec": "log(CO2/GDP) ~ lowcarbon DID + city trends (HC1)", "coef": b, "se": se, "n": n})

        d2 = drop_municipalities(d1)
        b, se, n = twfe_cluster(
            d2, formula="log_co2_per_gdp ~ lowcarbon_did_tmp + C(city_code6) + C(year)", coef="lowcarbon_did_tmp"
        )
        rob_rows.append({"spec": "log(CO2/GDP) ~ lowcarbon DID (drop municipalities)", "coef": b, "se": se, "n": n})

    # Good-day share
    if "good_day_share" in df.columns and df["good_day_share"].notna().any():
        d1 = df.dropna(subset=["good_day_share"]).copy()
        b, se, n = twfe_with_city_trends(
            d1, formula_rhs="good_day_share ~ carbon_trading_did_tmp + C(city_code6) + C(year)", coef="carbon_trading_did_tmp"
        )
        rob_rows.append({"spec": "Good-day share ~ carbon-trading DID + city trends (HC1)", "coef": b, "se": se, "n": n})

        d2 = drop_municipalities(d1)
        b, se, n = twfe_cluster(
            d2, formula="good_day_share ~ carbon_trading_did_tmp + C(city_code6) + C(year)", coef="carbon_trading_did_tmp"
        )
        rob_rows.append({"spec": "Good-day share ~ carbon-trading DID (drop municipalities)", "coef": b, "se": se, "n": n})

    if rob_rows:
        rob = pd.DataFrame(rob_rows)
        rob["coef"] = rob["coef"].map(lambda x: f"{x:.4f}")
        rob["se"] = rob["se"].map(lambda x: f"{x:.4f}")
        regression_tables.append(("Robustness (Trends / Exclusions)", rob))

    write_markdown_report(report_path, panel_path=panel_path, summaries=summaries, regression_tables=regression_tables)
    print(f"Wrote report: {report_path}")
    print(f"Wrote figures under: {fig_dir}")


if __name__ == "__main__":
    # Script entry point: call main() only when this file is executed directly,
    # not when it is imported as a module.
    main()
